import  os
import 	argparse

# Command-line interface. The parser gets its own name so that `args` only
# ever refers to the parsed option namespace used by the rest of the script.
parser = argparse.ArgumentParser(
    description='Split the files of an annotation directory into '
                'trainval/train/val/test name lists.')
parser.add_argument('--xml', type=str, default='Annotations',
                    help='directory whose file names are split '
                         '(default: %(default)s)')
parser.add_argument('--savepath', type=str, default='ImageSets/Main',
                    help='directory that receives trainval.txt, train.txt, '
                         'val.txt and test.txt (default: %(default)s)')
parser.add_argument('--trainval_ratio', type=float, default=0.9,
                    help='fraction of all files assigned to trainval; the '
                         'rest go to test (default: %(default)s)')
parser.add_argument('--train_ratio', type=float, default=0.7,
                    help='fraction of the trainval files assigned to train; '
                         'the rest go to val (default: %(default)s)')

# Parsed at module level: the script body reads its options from `args`.
args = parser.parse_args()

def split_names(filenames, trainval_ratio, train_ratio):
    """Split file names into trainval/train/val/test lists of base names.

    The file names are sorted, their extension is removed, and the sorted
    sequence is cut positionally (no shuffling): the first
    ``int(len * trainval_ratio)`` names form trainval and the rest form test;
    the first ``int(trainval_count * train_ratio)`` names of trainval form
    train and the rest of trainval form val.

    Args:
        filenames: Iterable of file names (e.g. the result of ``os.listdir``).
        trainval_ratio: Fraction of all names assigned to trainval.
        train_ratio: Fraction of the trainval names assigned to train.

    Returns:
        A tuple ``(trainval, train, val, test)`` of lists of base names.
    """
    # Sort on the full file name first, then strip the extension with
    # splitext so names with an extension of any length are handled
    # (a fixed [:-4] slice would corrupt e.g. "img.jpeg").
    names = [os.path.splitext(filename)[0] for filename in sorted(filenames)]

    # Clamp at zero so a negative ratio selects nothing instead of turning
    # into a negative slice index; counts above the length simply select all.
    trainval_num = max(0, int(len(names) * trainval_ratio))
    train_num = max(0, int(trainval_num * train_ratio))

    trainval = names[:trainval_num]
    test = names[trainval_num:]
    train = trainval[:train_num]
    val = trainval[train_num:]
    return trainval, train, val, test


if __name__ == '__main__':
    trainval, train, val, test = split_names(
        os.listdir(args.xml), args.trainval_ratio, args.train_ratio)

    # Create the output directory if needed; an empty path means the current
    # directory, which os.makedirs would reject.
    if args.savepath:
        os.makedirs(args.savepath, exist_ok=True)

    outputs = (
        ('trainval.txt', trainval),
        ('test.txt', test),
        ('train.txt', train),
        ('val.txt', val),
    )
    # Every list file is written (even when empty), one base name per line.
    # The context manager guarantees the file is closed if a write fails.
    for list_filename, names in outputs:
        with open(os.path.join(args.savepath, list_filename), 'w') as list_file:
            list_file.writelines(name + '\n' for name in names)